{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data splitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.data import SplitType, make_split_indices, split_data_by_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are example [datapoints](./datapoints.ipynb) to split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from chemprop.data import MoleculeDatapoint\n",
    "\n",
    "smis = [\"C\" * i for i in range(1, 11)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A typical Chemprop workflow uses three sets of data. The first is used to train the model. The second is used as validation for early stopping and hyperparameter optimization. The third is used to test the final model's performance as an estimate of how it will perform on future data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop provides helper functions to split data into these training, validation, and test sets. Available splitting schemes are listed in `SplitType`.\n",
    "All of these rely on [`astartes`](https://github.com/JacksonBurns/astartes) in the backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scaffold_balanced\n",
      "random_with_repeated_smiles\n",
      "random\n",
      "kennard_stone\n",
      "kmeans\n"
     ]
    }
   ],
   "source": [
    "for splittype in SplitType:\n",
    "    print(splittype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Splitting steps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Collect the `rdkit.Chem.Mol` objects for each datapoint. These are required for structure-based splits.\n",
    "2. Generate the splitting indices.\n",
    "3. Split the data using those indices.\n",
    "\n",
    "The `make_split_indices` function includes a `num_replicates` argument to perform repeated splits (each with a different random seed) with your sampler of choice.\n",
    "Any sampler can be used for replicates, though deterministic samplers (e.g., Kennard-Stone) will produce the same split on every replicate.\n",
    "Splits are returned as a 2- or 3-member tuple containing `num_replicates`-length lists of training, validation, and testing indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [d.mol for d in datapoints]\n",
    "\n",
    "train_indices, val_indices, test_indices = make_split_indices(mols)\n",
    "\n",
    "train_data, val_data, test_data = split_data_by_indices(\n",
    "    datapoints, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default splitting scheme is a random split with 80% of the data used to train, 10% to validate, and 10% to test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 1, 1)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data), len(val_data), len(test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each of these is length 1 because we only requested 1 replicate (the default).\n",
    "The inner list of each set contains the actual datapoints for training, validation, or testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8, 1, 1)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data[0]), len(val_data[0]), len(test_data[0])"
   ]
  },
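  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the nesting concrete, here is a minimal pure-Python sketch of how a list of per-replicate index lists selects items. The helper `take_by_indices` and the variable names are hypothetical and only illustrate the idea; this is not Chemprop's `split_data_by_indices` implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def take_by_indices(items, indices_per_replicate):\n",
    "    # Build one inner list of items per replicate (hypothetical helper)\n",
    "    return [[items[i] for i in idxs] for idxs in indices_per_replicate]\n",
    "\n",
    "\n",
    "letters = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n",
    "letters_train = take_by_indices(letters, [[4, 0, 2]])  # a single replicate\n",
    "\n",
    "# Outer length is the number of replicates, inner length is the set size\n",
    "len(letters_train), len(letters_train[0])"
   ]
  },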
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split randomness"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All split randomness uses a default seed of 0 and `numpy.random`. The seed can be changed to get different splits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[8, 4, 9, 1, 6, 7, 3, 0]], [[5]], [[2]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[8, 7, 0, 4, 9, 3, 2, 1]], [[6]], [[5]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints, seed=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split fractions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split sizes can also be changed. Set the middle value to 0 for a two-way split. If the data cannot be split to exactly the specified proportions, you will get a warning from `astartes` with the actual sizes used. If the specified sizes don't sum to 1, they will first be rescaled to sum to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[8, 4, 9, 1]], [[6, 7, 3]], [[0, 5, 2]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints, sizes=(0.4, 0.3, 0.3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[8, 4, 9, 1, 6, 7]], [[]], [[3, 0, 5, 2]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints, sizes=(0.6, 0.0, 0.4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested validation size of 0.25, got 0.30. Requested test size of 0.25, got 0.30. \n",
      "  warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([[8, 4, 9, 1, 6]], [[7, 3]], [[0, 5, 2]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints, sizes=(0.5, 0.25, 0.25))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:381: NormalizationWarning: Requested train/val/test split (0.50, 0.50, 0.50) do not sum to 1.0, normalizing to train=0.33, val=0.33, test=0.33.\n",
      "  warn(\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested train size of 0.33, got 0.30. Requested test size of 0.33, got 0.20. \n",
      "  warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([[8, 4, 9]], [[1, 6, 7, 3, 0]], [[5, 2]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(datapoints, sizes=(0.5, 0.5, 0.5))"
   ]
  },
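  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rescaling of sizes that do not sum to 1 can be sketched with plain NumPy. The helper name `normalize_sizes` is an assumption made for illustration; `astartes` performs the actual normalization internally, and its implementation may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize_sizes(sizes):\n",
    "    # Rescale the requested fractions so that they sum to 1 (hypothetical helper)\n",
    "    sizes = np.asarray(sizes, dtype=float)\n",
    "    return sizes / sizes.sum()\n",
    "\n",
    "\n",
    "# Three equal requests become three equal thirds\n",
    "normalize_sizes((0.5, 0.5, 0.5))"
   ]
  },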
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random with repeated molecules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your dataset has repeated molecules, all duplicate molecules should go in the same split. This split type requires the `rdkit.Chem.Mol` objects of the datapoints. It first removes duplicates, then uses `astartes` to make the random splits, and finally adds the duplicate datapoints back in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "smis = [\"O\", \"O\"] + [\"C\" * i for i in range(1, 10)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "repeat_datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "mols = [d.mol for d in repeat_datapoints]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[10, 6, 0, 1, 3, 8, 9, 5, 2]], [[7]], [[4]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(mols, split=\"random_with_repeated_smiles\")"
   ]
  },
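  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remove-duplicates, split, and add-back procedure can be sketched without Chemprop. The cell below is an illustration under stated assumptions, not the library's implementation: it groups items by a string key (standing in for a canonical SMILES), shuffles the unique keys with `numpy.random.default_rng`, makes a two-way train/test split, and then expands each key back to all of its original indices. The function `split_with_repeats` and its arguments are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_with_repeats(keys, train_frac=0.8, seed=0):\n",
    "    # Group the original indices by key so duplicates stay together\n",
    "    groups = {}\n",
    "    for idx, key in enumerate(keys):\n",
    "        groups.setdefault(key, []).append(idx)\n",
    "\n",
    "    # Shuffle the unique keys and split them into train and test\n",
    "    unique_keys = list(groups)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    rng.shuffle(unique_keys)\n",
    "    n_train = int(round(train_frac * len(unique_keys)))\n",
    "\n",
    "    # Add the duplicates back by expanding each key to all of its indices\n",
    "    train = [i for k in unique_keys[:n_train] for i in groups[k]]\n",
    "    test = [i for k in unique_keys[n_train:] for i in groups[k]]\n",
    "    return train, test\n",
    "\n",
    "\n",
    "repeat_keys = [\"O\", \"O\", \"C\", \"CC\", \"CCC\", \"CCCC\"]\n",
    "repeat_train, repeat_test = split_with_repeats(repeat_keys)\n",
    "\n",
    "# Indices 0 and 1 share a key, so they always land in the same set\n",
    "repeat_train, repeat_test"
   ]
  },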
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Structure-based splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Placing all similar molecules in only one of the datasets can give a more realistic estimate of how a model will perform on unseen chemistry. These splits use the `rdkit.Chem.Mol` representation of the molecules. See the `astartes` [documentation](https://jacksonburns.github.io/astartes/) for details about the Kennard-Stone, k-means, and scaffold-balanced splitting schemes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "smis = [\n",
    "    \"Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14\",\n",
    "    \"COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23\",\n",
    "    \"COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl\",\n",
    "    \"OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3\",\n",
    "    \"Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1\",\n",
    "    \"OC1(CN2CCC1CC2)C#Cc3ccc(cc3)c4ccccc4\",\n",
    "    \"COc1cc(OC)c(cc1NC(=O)CCC(=O)O)S(=O)(=O)NCc2ccccc2N3CCCCC3\",\n",
    "    \"CNc1cccc(CCOc2ccc(C[C@H](NC(=O)c3c(Cl)cccc3Cl)C(=O)O)cc2C)n1\",\n",
    "    \"COc1ccc(cc1)C2=COc3cc(OC)cc(OC)c3C2=O\",\n",
    "    \"Oc1ncnc2scc(c3ccsc3)c12\",\n",
    "    \"CS(=O)(=O)c1ccc(Oc2ccc(cc2)C#C[C@]3(O)CN4CCC3CC4)cc1\",\n",
    "    \"C[C@H](Nc1nc(Nc2cc(C)[nH]n2)c(C)nc1C#N)c3ccc(F)cn3\",\n",
    "    \"O=C1CCCCCN1\",\n",
    "    \"CCCSc1ncccc1C(=O)N2CCCC2c3ccncc3\",\n",
    "    \"CC1CCCCC1NC(=O)c2cnn(c2NS(=O)(=O)c3ccc(C)cc3)c4ccccc4\",\n",
    "    \"Nc1ccc(cc1)c2nc3ccc(O)cc3s2\",\n",
    "    \"COc1ccc(cc1)N2CCN(CC2)C(=O)[C@@H]3CCCC[C@H]3C(=O)NCC#N\",\n",
    "    \"CCC(COC(=O)c1cc(OC)c(OC)c(OC)c1)(N(C)C)c2ccccc2\",\n",
    "    \"COc1cc(ccc1N2CC[C@@H](O)C2)N3N=Nc4cc(sc4C3=O)c5ccc(Cl)cc5\",\n",
    "    \"CO[C@H]1CN(CCN2C(=O)C=Cc3ccc(cc23)C#N)CC[C@H]1NCc4ccc5OCC(=O)Nc5n4\",\n",
    "]\n",
    "\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "mols = [d.mol for d in datapoints]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/astartes/main.py:325: ImperfectSplittingWarning: Actual train/test split differs from requested size. Requested train size of 0.80, got 0.85. Requested test size of 0.10, got 0.05. \n",
      "  warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([[0, 1, 2, 3, 4, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19]],\n",
       " [[5, 10]],\n",
       " [[7]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_split_indices(mols, split=\"kmeans\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
